{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transformers 설치 방법\n",
    "! pip install transformers datasets evaluate accelerate\n",
    "# 마지막 릴리스 대신 소스에서 설치하려면, 위 명령을 주석으로 바꾸고 아래 명령을 해제하세요.\n",
    "# ! pip install git+https://github.com/huggingface/transformers.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 시각적 질의응답 (Visual Question Answering)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "시각적 질의응답(VQA)은 이미지를 기반으로 개방형 질문에 대응하는 작업입니다. 이 작업을 지원하는 모델의 입력은 대부분 이미지와 질문의 조합이며, 출력은 자연어로 된 답변입니다.\n",
    "\n",
    "VQA의 주요 사용 사례는 다음과 같습니다:\n",
    "* 시각 장애인을 위한 접근성 애플리케이션을 구축할 수 있습니다.\n",
    "* 교육: 강의나 교과서에 나온 시각 자료에 대한 질문에 답할 수 있습니다. 또한 체험형 전시와 유적 등에서도 VQA를 활용할 수 있습니다.\n",
    "* 고객 서비스 및 전자상거래: VQA는 사용자가 제품에 대해 질문할 수 있게 함으로써 사용자 경험을 향상시킬 수 있습니다.\n",
    "* 이미지 검색: VQA 모델을 사용하여 원하는 특성을 가진 이미지를 검색할 수 있습니다. 예를 들어 사용자는 \"강아지가 있어?\"라고 물어봐서 주어진 이미지 묶음에서 강아지가 있는 모든 이미지를 받아볼 수 있습니다.\n",
    "\n",
    "이 가이드에서 학습할 내용은 다음과 같습니다:\n",
    "\n",
    "- VQA 모델 중 하나인 [ViLT](https://huggingface.co/docs/transformers/main/ko/tasks/../../en/model_doc/vilt)를 [`Graphcore/vqa` 데이터셋](https://huggingface.co/datasets/Graphcore/vqa) 에서 미세조정하는 방법\n",
    "- 미세조정된 ViLT 모델로 추론하는 방법\n",
    "- BLIP-2 같은 생성 모델로 제로샷 VQA 추론을 실행하는 방법"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ViLT 미세 조정 [[finetuning-vilt]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ViLT는 Vision Transformer (ViT) 내에 텍스트 임베딩을 포함하여 비전/자연어 사전훈련(VLP; Vision-and-Language Pretraining)을 위한 기본 디자인을 제공합니다.\n",
    "ViLT 모델은 비전 트랜스포머(ViT)에 텍스트 임베딩을 넣어 비전/언어 사전훈련(VLP; Vision-and-Language Pre-training)을 위한 기본적인 디자인을 갖췄습니다. 이 모델은 여러 다운스트림 작업에 사용할 수 있습니다. VQA 태스크에서는 (`[CLS]` 토큰의 최종 은닉 상태 위에 선형 레이어인) 분류 헤더가 있으며 무작위로 초기화됩니다.\n",
    "따라서 여기에서 시각적 질의응답은 **분류 문제**로 취급됩니다.\n",
    "\n",
    "최근의 BLIP, BLIP-2, InstructBLIP와 같은 모델들은 VQA를 생성형 작업으로 간주합니다. 가이드의 후반부에서는 이런 모델들을 사용하여 제로샷 VQA 추론을 하는 방법에 대해 설명하겠습니다.\n",
    "\n",
    "시작하기 전 필요한 모든 라이브러리를 설치했는지 확인하세요.\n",
    "\n",
    "```bash\n",
    "pip install -q transformers datasets\n",
    "```\n",
    "\n",
    "커뮤니티에 모델을 공유하는 것을 권장 드립니다. Hugging Face 계정에 로그인하여 🤗 Hub에 업로드할 수 있습니다.\n",
    "메시지가 나타나면 로그인할 토큰을 입력하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "모델 체크포인트를 전역 변수로 선언하세요."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_checkpoint = \"dandelin/vilt-b32-mlm\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 데이터 가져오기 [[load-the-data]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이 가이드에서는 `Graphcore/vqa` 데이터세트의 작은 샘플을 사용합니다. 전체 데이터세트는 [🤗 Hub](https://huggingface.co/datasets/Graphcore/vqa) 에서 확인할 수 있습니다.\n",
    "\n",
    "[`Graphcore/vqa` 데이터세트](https://huggingface.co/datasets/Graphcore/vqa) 의 대안으로 공식 [VQA 데이터세트 페이지](https://visualqa.org/download.html) 에서 동일한 데이터를 수동으로 다운로드할 수 있습니다. 직접 공수한 데이터로 튜토리얼을 따르고 싶다면 [이미지 데이터세트 만들기](https://huggingface.co/docs/datasets/image_dataset#loading-script) 라는\n",
    "🤗 Datasets 문서를 참조하세요.\n",
    "\n",
    "검증 데이터의 첫 200개 항목을 불러와 데이터세트의 특성을 확인해 보겠습니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['question', 'question_type', 'question_id', 'image_id', 'answer_type', 'label'],\n",
       "    num_rows: 200\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"Graphcore/vqa\", split=\"validation[:200]\")\n",
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "예제를 하나 뽑아 데이터세트의 특성을 이해해 보겠습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'question': 'Where is he looking?',\n",
       " 'question_type': 'none of the above',\n",
       " 'question_id': 262148000,\n",
       " 'image_id': '/root/.cache/huggingface/datasets/downloads/extracted/ca733e0e000fb2d7a09fbcc94dbfe7b5a30750681d0e965f8e0a23b1c2f98c75/val2014/COCO_val2014_000000262148.jpg',\n",
       " 'answer_type': 'other',\n",
       " 'label': {'ids': ['at table', 'down', 'skateboard', 'table'],\n",
       "  'weights': [0.30000001192092896,\n",
       "   1.0,\n",
       "   0.30000001192092896,\n",
       "   0.30000001192092896]}}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "데이터세트에는 다음과 같은 특성이 포함되어 있습니다:\n",
    "* `question`: 이미지에 대한 질문\n",
    "* `image_id`: 질문과 관련된 이미지의 경로\n",
    "* `label`: 데이터의 레이블 (annotations)\n",
    "\n",
    "나머지 특성들은 필요하지 않기 때문에 삭제해도 됩니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = dataset.remove_columns(['question_type', 'question_id', 'answer_type'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "보시다시피 `label` 특성은 같은 질문마다 답변이 여러 개 있을 수 있습니다. 모두 다른 데이터 라벨러들로부터 수집되었기 때문인데요. 질문의 답변은 주관적일 수 있습니다. 이 경우 질문은 \"그는 어디를 보고 있나요?\" 였지만, 어떤 사람들은 \"아래\"로 레이블을 달았고, 다른 사람들은 \"테이블\" 또는 \"스케이트보드\" 등으로 주석을 달았습니다.\n",
    "\n",
    "아래의 이미지를 보고 어떤 답변을 선택할 것인지 생각해 보세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "image = Image.open(dataset[0]['image_id'])\n",
    "image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/vqa-example.png\" alt=\"VQA Image Example\"/>\n",
    "</div>\n",
    "\n",
    "질문과 답변의 모호성으로 인해 이러한 데이터세트는 여러 개의 답변이 가능하므로 다중 레이블 분류 문제로 처리됩니다. 게다가, 원핫(one-hot) 인코딩 벡터를 생성하기보다는 레이블에서 특정 답변이 나타나는 횟수를 기반으로 소프트 인코딩을 생성합니다.\n",
    "\n",
    "위의 예시에서 \"아래\"라는 답변이 다른 답변보다 훨씬 더 자주 선택되었기 때문에 데이터세트에서 `weight`라고 불리는 점수로 1.0을 가지며, 나머지 답변들은 1.0 미만의 점수를 가집니다.\n",
    "\n",
    "적절한 분류 헤더로 모델을 나중에 인스턴스화하기 위해 레이블을 정수로 매핑한 딕셔너리 하나, 반대로 정수를 레이블로 매핑한 딕셔너리 하나 총 2개의 딕셔너리를 생성하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "labels = [item['ids'] for item in dataset['label']]\n",
    "flattened_labels = list(itertools.chain(*labels))\n",
    "unique_labels = list(set(flattened_labels))\n",
    "\n",
    "label2id = {label: idx for idx, label in enumerate(unique_labels)}\n",
    "id2label = {idx: label for label, idx in label2id.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이제 매핑이 완료되었으므로 문자열 답변을 해당 id로 교체하고, 데이터세트의 더 편리한 후처리를 위해 편평화 할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'question': Value(dtype='string', id=None),\n",
       " 'image_id': Value(dtype='string', id=None),\n",
       " 'label.ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),\n",
       " 'label.weights': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None)}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def replace_ids(inputs):\n",
    "  inputs[\"label\"][\"ids\"] = [label2id[x] for x in inputs[\"label\"][\"ids\"]]\n",
    "  return inputs\n",
    "\n",
    "\n",
    "dataset = dataset.map(replace_ids)\n",
    "flat_dataset = dataset.flatten()\n",
    "flat_dataset.features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 데이터 전처리 [[preprocessing-data]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "다음 단계는 모델을 위해 이미지와 텍스트 데이터를 준비하기 위해 ViLT 프로세서를 가져오는 것입니다.\n",
    "`ViltProcessor`는 BERT 토크나이저와 ViLT 이미지 프로세서를 편리하게 하나의 프로세서로 묶습니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import ViltProcessor\n",
    "\n",
    "processor = ViltProcessor.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "데이터를 전처리하려면 이미지와 질문을 `ViltProcessor`로 인코딩해야 합니다. 프로세서는 [BertTokenizerFast](https://huggingface.co/docs/transformers/main/ko/model_doc/electra#transformers.BertTokenizer)로 텍스트를 토크나이즈하고 텍스트 데이터를 위해 `input_ids`, `attention_mask` 및 `token_type_ids`를 생성합니다.\n",
    "이미지는 `ViltImageProcessor`로 이미지를 크기 조정하고 정규화하며, `pixel_values`와 `pixel_mask`를 생성합니다.\n",
    "\n",
    "이런 전처리 단계는 모두 내부에서 이루어지므로, `processor`를 호출하기만 하면 됩니다. 하지만 아직 타겟 레이블이 완성되지 않았습니다. 타겟의 표현에서 각 요소는 가능한 답변(레이블)에 해당합니다. 정확한 답변의 요소는 해당 점수(weight)를 유지시키고 나머지 요소는 0으로 설정해야 합니다.\n",
    "\n",
    "아래 함수가 위에서 설명한대로 이미지와 질문에 `processor`를 적용하고 레이블을 형식에 맞춥니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def preprocess_data(examples):\n",
    "    image_paths = examples['image_id']\n",
    "    images = [Image.open(image_path) for image_path in image_paths]\n",
    "    texts = examples['question']\n",
    "\n",
    "    encoding = processor(images, texts, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
    "\n",
    "    for k, v in encoding.items():\n",
    "          encoding[k] = v.squeeze()\n",
    "\n",
    "    targets = []\n",
    "\n",
    "    for labels, scores in zip(examples['label.ids'], examples['label.weights']):\n",
    "        target = torch.zeros(len(id2label))\n",
    "\n",
    "        for label, score in zip(labels, scores):\n",
    "            target[label] = score\n",
    "\n",
    "        targets.append(target)\n",
    "\n",
    "    encoding[\"labels\"] = targets\n",
    "\n",
    "    return encoding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "전체 데이터세트에 전처리 함수를 적용하려면 🤗 Datasets의 `map` 함수를 사용하십시오. `batched=True`를 설정하여 데이터세트의 여러 요소를 한 번에 처리함으로써 `map`을 더 빠르게 할 수 있습니다. 이 시점에서 필요하지 않은 열은 제거하세요."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input_ids', 'token_type_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'],\n",
       "    num_rows: 200\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "processed_dataset = flat_dataset.map(preprocess_data, batched=True, remove_columns=['question','question_type',  'question_id', 'image_id', 'answer_type', 'label.ids', 'label.weights'])\n",
    "processed_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "마지막 단계로, [DefaultDataCollator](https://huggingface.co/docs/transformers/main/ko/main_classes/data_collator#transformers.DefaultDataCollator)를 사용하여 예제로 쓸 배치를 생성하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DefaultDataCollator\n",
    "\n",
    "data_collator = DefaultDataCollator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 모델 훈련 [[train-the-model]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이제 모델을 훈련하기 위해 준비되었습니다! `ViltForQuestionAnswering`으로 ViLT를 가져올 차례입니다. 레이블의 수와 레이블 매핑을 지정하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import ViltForQuestionAnswering\n",
    "\n",
    "model = ViltForQuestionAnswering.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이 시점에서는 다음 세 단계만 남았습니다:\n",
    "\n",
    "1. [TrainingArguments](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.TrainingArguments)에서 훈련 하이퍼파라미터를 정의하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "repo_id = \"MariaK/vilt_finetuned_200\"\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=repo_id,\n",
    "    per_device_train_batch_size=4,\n",
    "    num_train_epochs=20,\n",
    "    save_steps=200,\n",
    "    logging_steps=50,\n",
    "    learning_rate=5e-5,\n",
    "    save_total_limit=2,\n",
    "    remove_unused_columns=False,\n",
    "    push_to_hub=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. 모델, 데이터세트, 프로세서, 데이터 콜레이터와 함께 훈련 인수를 [Trainer](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer)에 전달하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=processed_dataset,\n",
    "    processing_class=processor,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. [train()](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer.train)을 호출하여 모델을 미세 조정하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "훈련이 완료되면, [push_to_hub()](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer.push_to_hub) 메소드를 사용하여 🤗 Hub에 모델을 공유하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 추론 [[inference]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ViLT 모델을 미세 조정하고 🤗 Hub에 업로드했다면 추론에 사용할 수 있습니다. 미세 조정된 모델을 추론에 사용해보는 가장 간단한 방법은 [Pipeline](https://huggingface.co/docs/transformers/main/ko/main_classes/pipelines#transformers.Pipeline)에서 사용하는 것입니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "pipe = pipeline(\"visual-question-answering\", model=\"MariaK/vilt_finetuned_200\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이 가이드의 모델은 200개의 예제에서만 훈련되었으므로 그다지 많은 것을 기대할 수는 없습니다. 데이터세트의 첫 번째 예제를 사용하여 추론 결과를 설명해보겠습니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"Where is he looking?\"\n",
       "[{'score': 0.5498199462890625, 'answer': 'down'}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example = dataset[0]\n",
    "image = Image.open(example['image_id'])\n",
    "question = example['question']\n",
    "print(question)\n",
    "pipe(image, question, top_k=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "비록 확신은 별로 없지만, 모델은 실제로 무언가를 배웠습니다. 더 많은 예제와 더 긴 훈련 기간이 주어진다면 분명 더 나은 결과를 얻을 수 있을 것입니다!\n",
    "\n",
    "원한다면 파이프라인의 결과를 수동으로 복제할 수도 있습니다:\n",
    "1. 이미지와 질문을 가져와서 프로세서를 사용하여 모델에 준비합니다.\n",
    "2. 전처리된 결과를 모델에 전달합니다.\n",
    "3. 로짓에서 가장 가능성 있는 답변의 id를 가져와서 `id2label`에서 실제 답변을 찾습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Predicted answer: down"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "processor = ViltProcessor.from_pretrained(\"MariaK/vilt_finetuned_200\")\n",
    "\n",
    "image = Image.open(example['image_id'])\n",
    "question = example['question']\n",
    "\n",
    "# prepare inputs\n",
    "inputs = processor(image, question, return_tensors=\"pt\")\n",
    "\n",
    "model = ViltForQuestionAnswering.from_pretrained(\"MariaK/vilt_finetuned_200\")\n",
    "\n",
    "# forward pass\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "\n",
    "logits = outputs.logits\n",
    "idx = logits.argmax(-1).item()\n",
    "print(\"Predicted answer:\", model.config.id2label[idx])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 제로샷 VQA [[zeroshot-vqa]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이전 모델은 VQA를 분류 문제로 처리했습니다. BLIP, BLIP-2 및 InstructBLIP와 같은 최근의 모델은 VQA를 생성 작업으로 접근합니다. [BLIP-2](https://huggingface.co/docs/transformers/main/ko/tasks/../../en/model_doc/blip-2)를 예로 들어 보겠습니다. 이 모델은 사전훈련된 비전 인코더와 LLM의 모든 조합을 사용할 수 있는 새로운 비전-자연어 사전 학습 패러다임을 도입했습니다. ([BLIP-2 블로그 포스트](https://huggingface.co/blog/blip-2)를 통해 더 자세히 알아볼 수 있어요)\n",
    "이를 통해 시각적 질의응답을 포함한 여러 비전-자연어 작업에서 SOTA를 달성할 수 있었습니다.\n",
    "\n",
    "이 모델을 어떻게 VQA에 사용할 수 있는지 설명해 보겠습니다. 먼저 모델을 가져와 보겠습니다. 여기서 GPU가 사용 가능한 경우 모델을 명시적으로 GPU로 전송할 것입니다. 이전에는 훈련할 때 쓰지 않은 이유는 [Trainer](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer)가 이 부분을 자동으로 처리하기 때문입니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor, Blip2ForConditionalGeneration\n",
    "import torch\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(\"Salesforce/blip2-opt-2.7b\")\n",
    "model = Blip2ForConditionalGeneration.from_pretrained(\"Salesforce/blip2-opt-2.7b\", dtype=torch.float16)\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "모델은 이미지와 텍스트를 입력으로 받으므로, VQA 데이터세트의 첫 번째 예제에서와 동일한 이미지/질문 쌍을 사용해 보겠습니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "example = dataset[0]\n",
    "image = Image.open(example['image_id'])\n",
    "question = example['question']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BLIP-2를 시각적 질의응답 작업에 사용하려면 텍스트 프롬프트가 `Question: {} Answer:` 형식을 따라야 합니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = f\"Question: {question} Answer:\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이제 모델의 프로세서로 이미지/프롬프트를 전처리하고, 처리된 입력을 모델을 통해 전달하고, 출력을 디코드해야 합니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"He is looking at the crowd\""
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = processor(image, text=prompt, return_tensors=\"pt\").to(device, torch.float16)\n",
    "\n",
    "generated_ids = model.generate(**inputs, max_new_tokens=10)\n",
    "generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()\n",
    "print(generated_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "보시다시피 모델은 군중을 인식하고, 얼굴의 방향(아래쪽을 보고 있음)을 인식했지만, 군중이 스케이터 뒤에 있다는 사실을 놓쳤습니다. 그러나 사람이 직접 라벨링한 데이터셋을 얻을 수 없는 경우에, 이 접근법은 빠르게 유용한 결과를 생성할 수 있습니다."
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
